{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b7973eef-7e5e-4ca4-844a-1d016f5a638b",
   "metadata": {},
   "source": [
    "# Building a Custom Agent\n",
    "\n",
    "In this cookbook we show you how to build a custom agent using LlamaIndex.\n",
    "\n",
    "The easiest way to build a custom agent is to simply subclass `CustomSimpleAgentWorker` and implement a few required functions. You have complete flexibility in defining the agent step-wise logic.\n",
    "\n",
    "This lets you add arbitrarily complex reasoning logic on top of your RAG pipeline.\n",
    "\n",
    "We show you how to build a simple agent that adds a retry layer on top of a RouterQueryEngine, allowing it to retry queries until the task is complete. We build this on top of both a SQL tool and a vector index query tool. Even if the tool makes an error or only answers part of the question, the agent can continue retrying the question until the task is complete."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fcf8f1d-9e41-434e-82a1-75471eb85275",
   "metadata": {},
   "source": [
    "## Setup the Custom Agent\n",
    "\n",
    "Here we setup the custom agent.\n",
    "\n",
    "### Refresher\n",
    "\n",
    "An agent in LlamaIndex consists of both an agent runner + agent worker. An agent runner is an orchestrator that stores state like memory, whereas an agent worker controls the step-wise execution of a Task. Agent runners include sequential, parallel execution. More details can be found in our [lower level API guide](https://docs.llamaindex.ai/en/latest/module_guides/deploying/agents/agent_runner.html).\n",
    "\n",
    "Most core agent logic (e.g. ReAct, function calling loops), can be executed in the agent worker. Therefore we've made it easy to subclass an agent worker, letting you plug it into any agent runner.\n",
    "\n",
    "### Creating a Custom Agent Worker Subclass\n",
    "\n",
    "As mentioned above we subclass `CustomSimpleAgentWorker`. This is a class that already sets up some scaffolding for you. This includes being able to take in tools, callbacks, LLM, and also ensures that the state/steps are properly formatted. In the meantime you mostly have to implement the following functions:\n",
    "\n",
    "- `_initialize_state`\n",
    "- `_run_step`\n",
    "- `_finalize_task`\n",
    "\n",
    "Some additional notes:\n",
    "- You can implement `_arun_step` as well if you want to support async chat in the agent.\n",
    "- You can choose to override `__init__` as long as you pass all remaining args, kwargs to `super()`\n",
    "- `CustomSimpleAgentWorker` is implemented as a Pydantic `BaseModel` meaning that you can define your own custom properties as well.\n",
    "\n",
    "Here are the full set of base properties on each `CustomSimpleAgentWorker` (that you need to/can pass in when constructing your custom agent):\n",
    "- `tools: Sequence[BaseTool]`\n",
    "- `tool_retriever: Optional[ObjectRetriever[BaseTool]]`\n",
    "- `llm: LLM`\n",
    "- `callback_manager: CallbackManager`\n",
    "- `verbose: bool`\n",
    "\n",
    "Note that `tools` and `tool_retriever` are mutually exclusive, you can only pass in one or the either (e.g. define a static list of tools or define a callable function that returns relevant tools given a user message). You can call `get_tools(message: str)` to return relevant tools given a message.\n",
    "\n",
    "All of these properties are accessible via `self` when defining your custom agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0aacdcad-f0c1-40f4-b319-6c4cf3b309c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.agent import CustomSimpleAgentWorker, Task, AgentChatResponse\n",
    "from typing import Dict, Any, List, Tuple, Optional\n",
    "from llama_index.tools import BaseTool, QueryEngineTool\n",
    "from llama_index.program import LLMTextCompletionProgram\n",
    "from llama_index.output_parsers import PydanticOutputParser\n",
    "from llama_index.query_engine import RouterQueryEngine\n",
    "from llama_index.prompts import ChatPromptTemplate, PromptTemplate\n",
    "from llama_index.selectors import PydanticSingleSelector\n",
    "from pydantic import Field, BaseModel"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b5280cc-1a53-4aba-9745-03b9f9fcd864",
   "metadata": {},
   "source": [
    "Here we define some helper variables and methods. E.g. the prompt template to use to detect errors as well as the response format in Pydantic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91db5f4f-06d8-46a0-af3e-cfa9b7c096f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.llms import ChatMessage, MessageRole\n",
    "\n",
    "DEFAULT_PROMPT_STR = \"\"\"\n",
    "Given previous question/response pairs, please determine if an error has occurred in the response, and suggest \\\n",
    "    a modified question that will not trigger the error.\n",
    "\n",
    "Examples of modified questions:\n",
    "- The question itself is modified to elicit a non-erroneous response\n",
    "- The question is augmented with context that will help the downstream system better answer the question.\n",
    "- The question is augmented with examples of negative responses, or other negative questions.\n",
    "\n",
    "An error means that either an exception has triggered, or the response is completely irrelevant to the question.\n",
    "\n",
    "Please return the evaluation of the response in the following JSON format.\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "def get_chat_prompt_template(\n",
    "    system_prompt: str, current_reasoning: Tuple[str, str]\n",
    ") -> ChatPromptTemplate:\n",
    "    system_msg = ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)\n",
    "    messages = [system_msg]\n",
    "    for raw_msg in current_reasoning:\n",
    "        if raw_msg[0] == \"user\":\n",
    "            messages.append(\n",
    "                ChatMessage(role=MessageRole.USER, content=raw_msg[1])\n",
    "            )\n",
    "        else:\n",
    "            messages.append(\n",
    "                ChatMessage(role=MessageRole.ASSISTANT, content=raw_msg[1])\n",
    "            )\n",
    "    return ChatPromptTemplate(message_templates=messages)\n",
    "\n",
    "\n",
    "class ResponseEval(BaseModel):\n",
    "    \"\"\"Evaluation of whether the response has an error.\"\"\"\n",
    "\n",
    "    has_error: bool = Field(\n",
    "        ..., description=\"Whether the response has an error.\"\n",
    "    )\n",
    "    new_question: str = Field(..., description=\"The suggested new question.\")\n",
    "    explanation: str = Field(\n",
    "        ...,\n",
    "        description=(\n",
    "            \"The explanation for the error as well as for the new question.\"\n",
    "            \"Can include the direct stack trace as well.\"\n",
    "        ),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1421133e-5091-425b-822c-c2c0b8f084c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydantic import PrivateAttr\n",
    "\n",
    "\n",
    "class RetryAgentWorker(CustomSimpleAgentWorker):\n",
    "    \"\"\"Agent worker that adds a retry layer on top of a router.\n",
    "\n",
    "    Continues iterating until there's no errors / task is done.\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    prompt_str: str = Field(default=DEFAULT_PROMPT_STR)\n",
    "    max_iterations: int = Field(default=10)\n",
    "\n",
    "    _router_query_engine: RouterQueryEngine = PrivateAttr()\n",
    "\n",
    "    def __init__(self, tools: List[BaseTool], **kwargs: Any) -> None:\n",
    "        \"\"\"Init params.\"\"\"\n",
    "        # validate that all tools are query engine tools\n",
    "        for tool in tools:\n",
    "            if not isinstance(tool, QueryEngineTool):\n",
    "                raise ValueError(\n",
    "                    f\"Tool {tool.metadata.name} is not a query engine tool.\"\n",
    "                )\n",
    "        self._router_query_engine = RouterQueryEngine(\n",
    "            selector=PydanticSingleSelector.from_defaults(),\n",
    "            query_engine_tools=tools,\n",
    "            verbose=kwargs.get(\"verbose\", False),\n",
    "        )\n",
    "        super().__init__(\n",
    "            tools=tools,\n",
    "            **kwargs,\n",
    "        )\n",
    "\n",
    "    def _initialize_state(self, task: Task, **kwargs: Any) -> Dict[str, Any]:\n",
    "        \"\"\"Initialize state.\"\"\"\n",
    "        return {\"count\": 0, \"current_reasoning\": []}\n",
    "\n",
    "    def _run_step(\n",
    "        self, state: Dict[str, Any], task: Task, input: Optional[str] = None\n",
    "    ) -> Tuple[AgentChatResponse, bool]:\n",
    "        \"\"\"Run step.\n",
    "\n",
    "        Returns:\n",
    "            Tuple of (agent_response, is_done)\n",
    "\n",
    "        \"\"\"\n",
    "        if input is not None:\n",
    "            # if input is specified, override input\n",
    "            new_input = input\n",
    "        elif \"new_input\" not in state:\n",
    "            new_input = task.input\n",
    "        else:\n",
    "            new_input = state[\"new_input\"]\n",
    "\n",
    "        if self.verbose:\n",
    "            print(f\"> Current Input: {new_input}\")\n",
    "\n",
    "        # first run router query engine\n",
    "        response = self._router_query_engine.query(new_input)\n",
    "\n",
    "        # append to current reasoning\n",
    "        state[\"current_reasoning\"].extend(\n",
    "            [(\"user\", new_input), (\"assistant\", str(response))]\n",
    "        )\n",
    "\n",
    "        # Then, check for errors\n",
    "        # dynamically create pydantic program for structured output extraction based on template\n",
    "        chat_prompt_tmpl = get_chat_prompt_template(\n",
    "            self.prompt_str, state[\"current_reasoning\"]\n",
    "        )\n",
    "        llm_program = LLMTextCompletionProgram.from_defaults(\n",
    "            output_parser=PydanticOutputParser(output_cls=ResponseEval),\n",
    "            prompt=chat_prompt_tmpl,\n",
    "            llm=self.llm,\n",
    "        )\n",
    "        # run program, look at the result\n",
    "        response_eval = llm_program(\n",
    "            query_str=new_input, response_str=str(response)\n",
    "        )\n",
    "        if not response_eval.has_error:\n",
    "            is_done = True\n",
    "        else:\n",
    "            is_done = False\n",
    "        state[\"new_input\"] = response_eval.new_question\n",
    "\n",
    "        if self.verbose:\n",
    "            print(f\"> Question: {new_input}\")\n",
    "            print(f\"> Response: {response}\")\n",
    "            print(f\"> Response eval: {response_eval.dict()}\")\n",
    "\n",
    "        # return response\n",
    "        return AgentChatResponse(response=str(response)), is_done\n",
    "\n",
    "    def _finalize_task(self, state: Dict[str, Any], **kwargs) -> None:\n",
    "        \"\"\"Finalize task.\"\"\"\n",
    "        # nothing to finalize here\n",
    "        # this is usually if you want to modify any sort of\n",
    "        # internal state beyond what is set in `_initialize_state`\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60740ffe-8791-4723-b0e8-2d67487a2e84",
   "metadata": {},
   "source": [
    "## Setup Data and Tools\n",
    "\n",
    "We setup both a SQL Tool as well as vector index tools for each city."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0146889a-0c15-48a0-8fe0-e4434c32f17d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.tools.query_engine import QueryEngineTool"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff905ae1-a068-4d2e-8ad3-44d625aa0947",
   "metadata": {},
   "source": [
    "### Setup SQL DB + Tool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b866150c-deab-4ba6-b735-a77f662d33ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sqlalchemy import (\n",
    "    create_engine,\n",
    "    MetaData,\n",
    "    Table,\n",
    "    Column,\n",
    "    String,\n",
    "    Integer,\n",
    "    select,\n",
    "    column,\n",
    ")\n",
    "from llama_index import SQLDatabase\n",
    "\n",
    "engine = create_engine(\"sqlite:///:memory:\", future=True)\n",
    "metadata_obj = MetaData()\n",
    "# create city SQL table\n",
    "table_name = \"city_stats\"\n",
    "city_stats_table = Table(\n",
    "    table_name,\n",
    "    metadata_obj,\n",
    "    Column(\"city_name\", String(16), primary_key=True),\n",
    "    Column(\"population\", Integer),\n",
    "    Column(\"country\", String(16), nullable=False),\n",
    ")\n",
    "\n",
    "metadata_obj.create_all(engine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbba0868-198a-417d-82cd-0d36911b7172",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sqlalchemy import insert\n",
    "\n",
    "rows = [\n",
    "    {\"city_name\": \"Toronto\", \"population\": 2930000, \"country\": \"Canada\"},\n",
    "    {\"city_name\": \"Tokyo\", \"population\": 13960000, \"country\": \"Japan\"},\n",
    "    {\"city_name\": \"Berlin\", \"population\": 3645000, \"country\": \"Germany\"},\n",
    "]\n",
    "for row in rows:\n",
    "    stmt = insert(city_stats_table).values(**row)\n",
    "    with engine.begin() as connection:\n",
    "        cursor = connection.execute(stmt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1639ad45-7c4b-4028-a192-e5d2fd64afcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.indices.struct_store.sql_query import NLSQLTableQueryEngine\n",
    "\n",
    "sql_database = SQLDatabase(engine, include_tables=[\"city_stats\"])\n",
    "sql_query_engine = NLSQLTableQueryEngine(\n",
    "    sql_database=sql_database, tables=[\"city_stats\"], verbose=True\n",
    ")\n",
    "sql_tool = QueryEngineTool.from_defaults(\n",
    "    query_engine=sql_query_engine,\n",
    "    description=(\n",
    "        \"Useful for translating a natural language query into a SQL query over\"\n",
    "        \" a table containing: city_stats, containing the population/country of\"\n",
    "        \" each city\"\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fab3526c-f73a-4403-b1a4-5893e85a88af",
   "metadata": {},
   "source": [
    "### Setup Vector Tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48e32ccc-2aca-4884-9da2-bf20a07351b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.readers import WikipediaReader\n",
    "from llama_index import VectorStoreIndex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86e2ff7c-a891-48df-a153-3dcaead671d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "cities = [\"Toronto\", \"Berlin\", \"Tokyo\"]\n",
    "wiki_docs = WikipediaReader().load_data(pages=cities)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9377221-dab4-4a8e-b5d1-3d76d5bb976d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# build a separate vector index per city\n",
    "# You could also choose to define a single vector index across all docs, and annotate each chunk by metadata\n",
    "vector_tools = []\n",
    "for city, wiki_doc in zip(cities, wiki_docs):\n",
    "    vector_index = VectorStoreIndex.from_documents([wiki_doc])\n",
    "    vector_query_engine = vector_index.as_query_engine()\n",
    "    vector_tool = QueryEngineTool.from_defaults(\n",
    "        query_engine=vector_query_engine,\n",
    "        description=f\"Useful for answering semantic questions about {city}\",\n",
    "    )\n",
    "    vector_tools.append(vector_tool)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62b3286f-ef5c-4c6c-8764-0b42e6181345",
   "metadata": {},
   "source": [
    "## Build Custom Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db3992b5-f433-44a5-8d97-fb0502143c80",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.agent import AgentRunner\n",
    "from llama_index.llms import OpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7824b9d2-185c-4d79-bdaa-c227b4a004f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = OpenAI(model=\"gpt-4\")\n",
    "callback_manager = llm.callback_manager\n",
    "\n",
    "query_engine_tools = [sql_tool] + vector_tools\n",
    "agent_worker = RetryAgentWorker.from_tools(\n",
    "    query_engine_tools,\n",
    "    llm=llm,\n",
    "    verbose=True,\n",
    "    callback_manager=callback_manager,\n",
    ")\n",
    "agent = AgentRunner(agent_worker, callback_manager=callback_manager)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0d6798c-a33c-455a-b4e4-fb0bb4208435",
   "metadata": {},
   "source": [
    "## Try Out Some Queries\n",
    "\n",
    "Let's run some e2e queries through `agent.chat`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28ae5efa-378b-4e0c-9b0b-72a460ada9fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> Current Input: Which countries are each city from?\n",
      "\u001b[1;3;38;5;200mSelecting query engine 0: The first choice is the most relevant because it mentions a table containing city_stats, which likely includes information about the country of each city..\n",
      "\u001b[0m> Question: Which countries are each city from?\n",
      "> Response: The city of Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.\n",
      "> Response eval: {'has_error': True, 'new_question': 'Which country is each of the following cities from: Toronto, Tokyo, Berlin?', 'explanation': 'The original question was too vague and did not specify which cities the user was interested in. The new question provides specific cities for the system to provide information on.'}\n",
      "> Current Input: Which country is each of the following cities from: Toronto, Tokyo, Berlin?\n",
      "\u001b[1;3;38;5;200mSelecting query engine 0: This choice is relevant because it mentions a table containing city_stats, which likely includes information about the country of each city..\n",
      "\u001b[0m> Question: Which country is each of the following cities from: Toronto, Tokyo, Berlin?\n",
      "> Response: Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.\n",
      "> Response eval: {'has_error': False, 'new_question': '', 'explanation': ''}\n",
      "Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.\n"
     ]
    }
   ],
   "source": [
    "response = agent.chat(\"Which countries are each city from?\")\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7300fe97-deb3-44ca-a375-2c14b173d17e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> Current Input: What are the top modes of transporation fo the city with the higehest population?\n",
      "\u001b[1;3;38;5;200mSelecting query engine 0: The question is asking about the top modes of transportation for the city with the highest population, which requires translating a natural language query into a SQL query over a table containing city statistics..\n",
      "\u001b[0m> Question: What are the top modes of transporation fo the city with the higehest population?\n",
      "> Response: I'm sorry, but there seems to be an error in the SQL query. Please check the syntax and try again.\n",
      "> Response eval: {'has_error': True, 'new_question': 'What are the top modes of transportation for the city with the highest population?', 'explanation': 'The original question had spelling errors which might have caused confusion. The corrected question now clearly asks for the top modes of transportation in the city with the highest population.'}\n",
      "> Current Input: What are the top modes of transportation for the city with the highest population?\n",
      "\u001b[1;3;38;5;200mSelecting query engine 0: This choice is relevant because it mentions translating a natural language query into a SQL query over a table containing city statistics, which could include information about the population of cities..\n",
      "\u001b[0m> Question: What are the top modes of transportation for the city with the highest population?\n",
      "> Response: The city with the highest population is Tokyo, Japan.\n",
      "> Response eval: {'has_error': True, 'new_question': 'What are the top modes of transportation in Tokyo, Japan?', 'explanation': 'The assistant did not answer the original question correctly. The question asked for the top modes of transportation in the city with the highest population, but the assistant only provided the city with the highest population. The new question directly asks for the top modes of transportation in Tokyo, Japan, which is the city with the highest population.'}\n",
      "> Current Input: What are the top modes of transportation in Tokyo, Japan?\n",
      "\u001b[1;3;38;5;200mSelecting query engine 3: Tokyo is mentioned in choice 4, which is specifically about answering semantic questions about Tokyo..\n",
      "\u001b[0m> Question: What are the top modes of transportation in Tokyo, Japan?\n",
      "> Response: The top modes of transportation in Tokyo, Japan are trains and subways, which are considered clean and efficient. There are also buses, monorails, and trams that play a secondary role in the transportation system. Additionally, there are expressways that connect Tokyo to other points in the Greater Tokyo Area and beyond. Taxis and long-distance ferries are also available for transportation within the city and to other islands.\n",
      "> Response eval: {'has_error': True, 'new_question': 'What are the top modes of transportation in Tokyo, Japan?', 'explanation': 'The original question was not answered correctly. The assistant provided information about the city with the highest population, but did not answer the question about the top modes of transportation in that city. The new question directly asks about the modes of transportation in Tokyo, Japan, which is the city with the highest population.'}\n",
      "> Current Input: What are the top modes of transportation in Tokyo, Japan?\n",
      "\u001b[1;3;38;5;200mSelecting query engine 3: The question specifically asks about Tokyo, and choice 4 is about answering semantic questions about Tokyo..\n",
      "\u001b[0m> Question: What are the top modes of transportation in Tokyo, Japan?\n",
      "> Response: The top modes of transportation in Tokyo, Japan are trains and subways. Tokyo has an extensive network of clean and efficient trains and subways operated by various operators. There are up to 62 electric train lines and more than 900 train stations in Tokyo. Buses, monorails, and trams also play a secondary role in public transportation within the city. Additionally, Tokyo has expressways, taxis, and long-distance ferries as other means of transportation.\n",
      "> Response eval: {'has_error': False, 'new_question': '', 'explanation': ''}\n",
      "The top modes of transportation in Tokyo, Japan are trains and subways. Tokyo has an extensive network of clean and efficient trains and subways operated by various operators. There are up to 62 electric train lines and more than 900 train stations in Tokyo. Buses, monorails, and trams also play a secondary role in public transportation within the city. Additionally, Tokyo has expressways, taxis, and long-distance ferries as other means of transportation.\n"
     ]
    }
   ],
   "source": [
    "response = agent.chat(\n",
    "    \"What are the top modes of transporation fo the city with the higehest population?\"\n",
    ")\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "926b79ba-1868-46ca-bb7e-2bd3c907773c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> Current Input: What are the sports teams of each city in Asia?\n",
      "\u001b[1;3;38;5;200mSelecting query engine 3: The question is asking about sports teams in Asia, and Tokyo is a city in Asia..\n",
      "\u001b[0m> Question: What are the sports teams of each city in Asia?\n",
      "> Response: I'm sorry, but the context information does not provide a comprehensive list of sports teams in each city in Asia. It only mentions some of the sports teams in Tokyo, Japan. To answer your question, I would need more specific information or a different source that provides a broader overview of sports teams in cities across Asia.\n",
      "> Response eval: {'has_error': True, 'new_question': 'What are some popular sports teams in Tokyo, Japan?', 'explanation': 'The original question is too broad and the system does not have enough information to provide a comprehensive list of sports teams in each city in Asia. The new question is more specific and focuses on sports teams in Tokyo, Japan, which the system has information on.'}\n",
      "> Current Input: What are some popular sports teams in Tokyo, Japan?\n",
      "\u001b[1;3;38;5;200mSelecting query engine 3: Tokyo is mentioned in choice 4, which is about answering semantic questions about Tokyo..\n",
      "\u001b[0m> Question: What are some popular sports teams in Tokyo, Japan?\n",
      "> Response: Some popular sports teams in Tokyo, Japan include the Yomiuri Giants and Tokyo Yakult Swallows in baseball, F.C. Tokyo and Tokyo Verdy 1969 in soccer, and the Hitachi SunRockers, Toyota Alvark Tokyo, and Tokyo Excellence in basketball. Tokyo is also known for its sumo wrestling tournaments held at the Ryōgoku Kokugikan sumo arena.\n",
      "> Response eval: {'has_error': False, 'new_question': '', 'explanation': ''}\n",
      "Some popular sports teams in Tokyo, Japan include the Yomiuri Giants and Tokyo Yakult Swallows in baseball, F.C. Tokyo and Tokyo Verdy 1969 in soccer, and the Hitachi SunRockers, Toyota Alvark Tokyo, and Tokyo Excellence in basketball. Tokyo is also known for its sumo wrestling tournaments held at the Ryōgoku Kokugikan sumo arena.\n"
     ]
    }
   ],
   "source": [
    "response = agent.chat(\"What are the sports teams of each city in Asia?\")\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a89391a4-8836-432a-a42f-9c831710f2e4",
   "metadata": {},
   "source": [
    "## Step-wise Queries\n",
    "\n",
    "We can also try some step-wise queries. This allows us to inject user feedback in the middle of a task execution to guide responses towards the correct state faster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04a4b1a2-5c7f-494f-bb35-a472dde47eab",
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_worker = RetryAgentWorker.from_tools(\n",
    "    query_engine_tools,\n",
    "    llm=llm,\n",
    "    verbose=True,\n",
    "    callback_manager=callback_manager,\n",
    ")\n",
    "agent = AgentRunner(agent_worker, callback_manager=callback_manager)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9e51266-9b29-48e7-b40f-cc45e7f8dd30",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> Current Input: Which countries are each city from?\n",
      "\u001b[1;3;38;5;200mSelecting query engine 0: The first choice is the most relevant because it mentions a table containing city_stats, which likely includes information about the country of each city..\n",
      "\u001b[0m> Question: Which countries are each city from?\n",
      "> Response: The city of Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.\n",
      "> Response eval: {'has_error': True, 'new_question': 'Can you tell me the country of origin for each of these cities: Toronto, Tokyo, Berlin?', 'explanation': 'The original question was too vague and did not specify which cities the user was interested in. The new question provides specific cities for the system to provide information on.'}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "AgentChatResponse(response='The city of Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.', sources=[], source_nodes=[])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "task = agent.create_task(\"Which countries are each city from?\")\n",
    "step_output = agent.run_step(task.task_id)\n",
    "step_output.output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fe5c76f-b183-43c8-b965-5ce576a70be3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> Current Input: Which country is each of the following cities from: Toronto, Tokyo, Berlin?\n",
      "\u001b[1;3;38;5;200mSelecting query engine 0: This choice is relevant because it mentions a table containing city_stats, which likely includes information about the country of each city..\n",
      "\u001b[0m> Question: Which country is each of the following cities from: Toronto, Tokyo, Berlin?\n",
      "> Response: Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.\n",
      "> Response eval: {'has_error': False, 'new_question': '', 'explanation': ''}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "AgentChatResponse(response='Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.', sources=[], source_nodes=[])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "step_output = agent.run_step(\n",
    "    task.task_id,\n",
    "    input=\"Which country is each of the following cities from: Toronto, Tokyo, Berlin?\",\n",
    ")\n",
    "step_output.output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bb30f0a-5947-456b-ab05-5db6d76bca7a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.\n"
     ]
    }
   ],
   "source": [
    "if step_output.is_last:\n",
    "    response = agent.finalize_response(task.task_id)\n",
    "    print(str(response))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama_index_v2",
   "language": "python",
   "name": "llama_index_v2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
